---
redirect_from:
  - "/03code/figure5"
interact_link: content/03Code/Figure5.ipynb
kernel_name: python3
kernel_path: content/03Code
has_widgets: false
title: |-
  Fig 5. Reevaluation of published results
pagenum: 6
prev_page:
  url: /03Code/Figure4.html
next_page:
  url: /03Code/Figure6.html
suffix: .ipynb
search:
comment: "***PROGRAMMATICALLY GENERATED, DO NOT EDIT. SEE ORIGINAL FILES IN /content***"
---
Fig 5. Reevaluation of published results
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from scipy.stats import hypergeom
from scipy.stats import binned_statistic as binsta
from scipy.special import logsumexp
from util import *
import palettable as pal
# Colour palettes taken from palettable's `mpl_colors` lists.
# clrx: the 10-colour qualitative "Prism" palette.
clrx = pal.cartocolors.qualitative.Prism_10.mpl_colors
# clr: only the Prism entries at positions 1, 2, 4, 5 and 6.
clr = tuple(x for n,x in enumerate(clrx) if n in [1,2,4,5,6])
# clr2: the 7-colour sequential "agSunset" palette.
clr2 = pal.cartocolors.sequential.agSunset_7.mpl_colors
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches

import plotly.graph_objects as go
import plotly.tools as tls
from plotly.offline import plot, iplot, init_notebook_mode
from IPython.core.display import display, HTML
# Initialise plotly's offline notebook mode.
init_notebook_mode(connected = True)
# Options passed as `config=` to every plot() call below.
config={'showLink': False, 'displayModeBar': False}
def ccp_sample(c,pool=60):
    """Simulate one run of the Coupon Collector's Problem (CCP).

    Draws `c` values with replacement from `range(pool)` and returns the
    number of distinct values that were drawn.
    """
    draws = np.random.choice(pool,c)
    return np.unique(draws).size

def nab_sample(s,na,nb,pool=60):
    """Draw an observed overlap between two samples.

    First draws how many of the `s` shared items land among `na` items
    taken without replacement from a pool of size `pool`; then draws how
    many of those land among `nb` items taken from a pool of the same size.
    Returns the second draw.
    """
    draw = np.random.hypergeometric
    shared_in_a = draw(s,pool-s,na)
    return draw(shared_in_a,pool-shared_in_a,nb)

def pcr_sample(c,s,pool=60):
    """Simulate the observed overlap between two PCRs of depth `c`.

    Parameters
    ----------
    c : number of draws (with replacement) made for each of the two samples.
    s : true overlap between the two populations.
    pool : population size, passed to `ccp_sample` and `nab_sample`.
        Defaults to 60, the value those helpers previously used implicitly.

    Returns
    -------
    (nab, na, nb) : the sampled overlap and the two distinct-item counts.
    """
    # Random draws happen in the order na, nb, nab.
    na = ccp_sample(c,pool)
    nb = ccp_sample(c,pool)
    return nab_sample(s,na,nb,pool),na,nb

def nab_sample_unequal(s,na,nb,pool_a,pool_b):
    """Draw an observed overlap when the two populations differ in size.

    Same two-stage hypergeometric draw as `nab_sample`, except that the
    first stage uses a pool of size `pool_a` (for the `na` draws) and the
    second stage a pool of size `pool_b` (for the `nb` draws). `s` is the
    true overlap; the return value is the empirical overlap.
    """
    draw = np.random.hypergeometric
    shared_in_a = draw(s,pool_a-s,na)
    return draw(shared_in_a,pool_b-shared_in_a,nb)


def p_ccp(c, pool=60):
    """Distribution of the number of distinct items after `c` draws.

    Returns an array of length `pool+1` whose entry k is the probability
    of having seen exactly k distinct items after `c` draws with
    replacement from `pool` equally likely items.
    """
    table = np.zeros([c+1,pool+1])
    table[0,0] = 1
    for draw in range(1,c+1):
        # After `draw` draws the recursion is applied for k = 1..min(draw+1, pool).
        top = min(draw+1,pool)
        ks = np.arange(1,top+1)
        # Either the new draw repeats one of the k items already seen ...
        repeat = table[draw-1,1:top+1]*ks/pool
        # ... or it is a new item, moving the count from k-1 to k.
        fresh = table[draw-1,0:top]*(1-(ks-1)/pool)
        table[draw,1:top+1] = repeat + fresh
    return table[-1,:]

def p_overlap(na,nb,nab,pool=60):
    """Normalised likelihood of each true overlap s given the observation.

    For every s in 0..pool, sums over the latent count sa the product
    P(sa | s, na) * P(nab | sa, nb), both hypergeometric, and normalises
    the resulting vector to sum to one.

    Parameters
    ----------
    na, nb : number of distinct items observed in samples a and b.
    nab : number of items observed in both samples.
    pool : population size used for both hypergeometric draws.

    Returns
    -------
    Array of length `pool+1`, indexed by s. If every s has zero likelihood
    the normalisation divides by zero and the entries are NaN.
    """
    counts = np.arange(pool+1)
    # reference: hypergeom.pmf(outcome, Total, hits, Draws, loc=0)
    # P(nab | sa, nb) does not depend on s, so evaluate it once for all sa.
    p_nab_given_sa = hypergeom.pmf(nab,pool,counts,nb)
    # p_sa[s, sa]: probability of getting sa of the s shared items in na draws.
    p_sa = hypergeom.pmf(counts[np.newaxis,:],pool,counts[:,np.newaxis],na)
    p_s = p_sa @ p_nab_given_sa
    return p_s/np.sum(p_s)

def e_overlap(na,nb,nab,pool=60):
    """Expected overlap: the mean of s under the distribution from `p_overlap`."""
    support = np.arange(pool+1)
    weights = p_overlap(na,nb,nab,pool=pool)
    return np.dot(support,weights)


def credible_interval(na,nb,nab,pct=90,pool=60):
    """Central interval and mean of the overlap distribution.

    `pct` may be given as a percentage (any value above 1, e.g. 90) or as
    a fraction (e.g. 0.9). Each tail receives (1 - pct)/2 of the mass.

    Returns `(lower, expectation, upper)`, where `lower` and `upper` are
    indices into the distribution returned by `p_overlap` and
    `expectation` is its mean.
    """
    posterior = p_overlap(na,nb,nab,pool=pool)
    # Work with a fraction regardless of how pct was supplied.
    level = pct/100 if pct > 1 else pct
    tail = (1-level)/2

    # Mass at or below each index, and mass at or above each index.
    mass_below = np.cumsum(posterior)
    mass_above = np.cumsum(posterior[::-1])[::-1]

    # Lower bound: the first index whose cumulative mass reaches the tail
    # cutoff; 0 if no index qualifies.
    hits = np.flatnonzero(mass_below >= tail)
    lower = hits[0] if hits.size else 0

    # Upper bound: the last index whose upper-tail mass still reaches the
    # cutoff; `pool` if no index qualifies.
    hits = np.flatnonzero(mass_above >= tail)
    upper = hits[-1] if hits.size else pool

    expectation = np.dot(np.arange(pool+1),posterior)
    return lower,expectation,upper


def p_nab_given_c(s,c,pool=60):
    """Joint probability of (na, nb, nab) for true overlap `s` and depth `c`.

    Both samples are assumed to be produced by `c` draws each, so na and
    nb independently follow `p_ccp(c, pool)`; nab given (na, nb, s) follows
    the same two-stage hypergeometric model used in `p_overlap`.

    Parameters
    ----------
    s : true overlap.
    c : number of draws per sample.
    pool : population size.

    Returns
    -------
    Array of shape `(pool+1, pool+1, pool+1)` indexed as `[na, nb, nab]`.
    Entries with na == 0 or nb == 0 are left at zero.
    """
    # na and nb share one distribution; it must be built for the same pool
    # as everything else (previously the default pool of p_ccp was used).
    p_n = p_ccp(c,pool=pool)
    counts = np.arange(pool+1)
    p_gen = np.zeros([pool+1,pool+1,pool+1])

    # Only sample sizes with non-zero probability can contribute mass.
    reachable = [k for k in range(1,pool+1) if p_n[k] > 0]
    # P(sa | s, na) depends only on na: one vector per reachable na.
    p_sa = {na: hypergeom.pmf(counts,pool,s,na) for na in reachable}

    for nb in reachable:
        # lik[nab, sa] = P(nab | sa, nb) depends only on nb.
        lik = hypergeom.pmf(counts[:,np.newaxis],pool,counts[np.newaxis,:],nb)
        for na in reachable:
            # nab ranges over 0..min(na, nb) inclusive; the upper end point
            # used to be skipped, which silently dropped probability mass.
            top = min(na,nb)+1
            p_gen[na,nb,:top] = (lik[:top] @ p_sa[na]) * (p_n[na]*p_n[nb])
    return p_gen

def p_shat_given_sc(s,c,shat,pool=60):
    """Distribution of the estimate `shat` for true overlap `s` and depth `c`.

    Sums the mass from `p_nab_given_c` into unit-width bins centred on the
    integers 0..pool, according to the value of `shat` at each entry.
    When the masses total less than 0.99, the Monte Carlo version is used
    instead. Returns the result object of scipy's `binned_statistic`.
    """
    masses = p_nab_given_c(s,c,pool=pool)
    total = np.sum(masses)
    if total < 0.99:
        print('Swapping to Monte Carlo')
        return p_shat_given_sc_montecarlo(s,c,shat,pool=pool)
    edges = np.arange(pool+2)-0.5
    return binsta(np.ravel(shat),np.ravel(masses),statistic='sum',bins=edges)

def p_shat_given_sc_montecarlo(s,c,shat,pool=60,n_mc=int(1e5)):
    """Monte Carlo estimate of the distribution of `shat` given `s` and `c`.

    Simulates `n_mc` pairs of samples, tallies how often each
    (na, nb, nab) combination occurs, and sums the resulting frequencies
    into unit-width bins centred on 0..pool according to `shat`.

    Parameters
    ----------
    s : true overlap.
    c : number of draws per sample.
    shat : array of estimates with the same number of entries as the
        `(pool+1, pool+1, pool+1)` tally array (it is flattened alongside it).
    pool : population size.
    n_mc : number of simulated pairs.

    Returns
    -------
    The result object of scipy's `binned_statistic`.
    """
    masses = np.zeros([pool+1,pool+1,pool+1])
    for _ in range(n_mc):
        # Simulate with the requested pool. `pcr_sample(c, s)` always uses
        # the helpers' default pool of 60, so it is not used here; the
        # draw order (na, nb, nab) is the same as in pcr_sample.
        na = ccp_sample(c,pool)
        nb = ccp_sample(c,pool)
        nab = nab_sample(s,na,nb,pool)
        masses[na,nb,nab] += 1
    hist = binsta(np.ravel(shat),np.ravel(masses/n_mc),statistic='sum',bins=(np.arange(pool+2)-0.5))
    return hist

def compute_all_estimates(pool=60):
    """Tabulate `e_overlap` for every (na, nb, nab) combination.

    Returns an array of shape `(pool+1, pool+1, pool+1)` indexed as
    `[na, nb, nab]`. Entries are filled for na, nb in 1..pool and nab in
    0..min(na, nb); all other entries stay at zero.
    """
    table = np.zeros([pool+1,pool+1,pool+1])
    sizes = range(1,pool+1)
    for na in sizes:
        for nb in sizes:
            largest = min(na,nb)
            for nab in range(largest+1):
                table[na,nb,nab] = e_overlap(na,nb,nab,pool=pool)
    return table

def p_overlap_unequal(na,nb,nab,pool_a,pool_b):
    """Normalised likelihood of each true overlap s for unequal pool sizes.

    Same model as `p_overlap`, but the `na` draws come from a pool of
    size `pool_a` and the `nb` draws from a pool of size `pool_b`.
    All indices run over 0..pool_a, which is assumed to be <= pool_b.

    Returns
    -------
    Array of length `pool_a+1`, indexed by s. If every s has zero
    likelihood the normalisation divides by zero and the entries are NaN.
    """
    counts = np.arange(pool_a+1)
    # reference: hypergeom.pmf(outcome, Total, hits, Draws, loc=0)
    # P(nab | sa, nb) does not depend on s, so evaluate it once for all sa.
    p_nab_given_sa = hypergeom.pmf(nab,pool_b,counts,nb)
    # p_sa[s, sa]: probability of getting sa of the s shared items in na draws of a.
    p_sa = hypergeom.pmf(counts[np.newaxis,:],pool_a,counts[:,np.newaxis],na)
    p_s = p_sa @ p_nab_given_sa
    return p_s/np.sum(p_s)

def e_overlap_unequal(na,nb,nab,pool_a,pool_b):
    """Expected overlap under `p_overlap_unequal`.

    TODO. Code expects that pool_b > pool_a...
    """
    support = np.arange(pool_a+1)
    weights = p_overlap_unequal(na,nb,nab,pool_a,pool_b)
    return np.dot(support,weights)

# The estimate table is loaded from 'shat_60.npy' rather than recomputed;
# the two commented lines show how that file was produced.
# shat = compute_all_estimates(pool=60)
# np.save('shat_60.npy',shat)
shat = np.load('shat_60.npy')
# Per-dataset results, keyed by a short dataset name:
bro = {}  # shat lookups for each pair
pts = {}  # 60*2*nab/(na+nb) for each pair
lci = {}  # lenci(...) for each pair

def lenci(na,nb,nab):
    """Number of integer overlap values inside the default credible interval.

    Calls `credible_interval` with pool=60 and its default `pct`, and
    counts both end points.
    """
    lo,_,hi = credible_interval(na,nb,nab,pool=60)
    return hi-lo+1


# --- "ame" dataset, read from data/amele.XLS ---
# '-' cells are replaced by 0 on load.
dfAme_all = pd.read_excel('data/amele.XLS').replace('-',0)
# Keep only the rows whose "source" column equals "AMELE".
dfAme = dfAme_all.loc[dfAme_all["source"]=="AMELE"]
# Keep as many leading columns as there are rows, giving a square table.
dfAme = dfAme.iloc[:,:len(dfAme)]

# NOTE(review): n is one less than the number of rows, so the last
# row/column never enters the pair loop below - presumably it is not a
# parasite (e.g. a totals row); confirm against the spreadsheet.
n = len(dfAme)-1
nab = dfAme.values.astype(int)
# The diagonal of the table is used as the per-parasite count na.
na = np.diagonal(nab)
dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values
# Every pair a < b among the first n rows/columns.
for a in np.arange(n-1):
    for b in np.arange(a+1,n):
        dummy.append(shat[na[a],na[b],nab[a,b]])
        dummz.append(60*2*nab[a,b]/(na[a]+na[b]))
        dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["ame"] = np.array(dummy)
pts["ame"] = np.array(dummz)
lci["ame"] = np.array(dummp)
print("Mean samples per parasite: {:.1f}".format(np.mean(na[:n])))


# --- "kilifi" dataset, read from data/kilifi.xlsx ---
# '-' cells are replaced by 0 on load.
dfKil = pd.read_excel('data/kilifi.xlsx').replace('-',0)
n = len(dfKil)-1
# na comes from the row labelled n (the last row under a default integer
# index), with values above 60 capped at 60.
# NOTE(review): presumably that row holds per-parasite totals - confirm.
na = np.array(dfKil.loc[n].apply(lambda x: 60 if x>60 else x)).astype(int)
nab = dfKil.values.astype(int)
dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values
# note that this file needs a and b swapped to get the lower diagonal
for a in np.arange(1,n):
    for b in np.arange(0,a):
        dummy.append(shat[na[a],na[b],nab[a,b]])
        dummz.append(60*2*nab[a,b]/(na[a]+na[b])) 
        dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["kilifi"] = np.array(dummy)
pts["kilifi"] = np.array(dummz)
lci["kilifi"] = np.array(dummp)

# --- "brazil" dataset, read from data/allbrazil.xlsx ---
# '-' cells are replaced by 0 on load.
dfBra = pd.read_excel('data/allbrazil.xlsx').replace('-',0)
n = len(dfBra)-1
# na comes from the row labelled n (the last row under a default integer
# index), with values above 60 capped at 60.
# NOTE(review): presumably that row holds per-parasite totals - confirm.
na = np.array(dfBra.loc[n].apply(lambda x: 60 if x>60 else x)).astype(int)
nab = dfBra.values.astype(int)
dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values
# Every pair a < b among indices 0..n-1 (upper triangle of nab).
for a in np.arange(n-1):
    for b in np.arange(a+1,n):
        dummy.append(shat[na[a],na[b],nab[a,b]])
        dummz.append(60*2*nab[a,b]/(na[a]+na[b]))
        dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["brazil"] = np.array(dummy)
pts["brazil"] = np.array(dummz)
lci["brazil"] = np.array(dummp)


# --- "ari" dataset, read from data/ariquemes_1.xlsx ---
# Substrings used to group columns; pairs are only formed inside a group.
pars = ["113","121","122","123","127","128","153","154"]
# '-' cells are replaced by 0 on load.
dfAri = pd.read_excel('data/ariquemes_1.xlsx').replace('-',0)

n = len(dfAri)-1
# na comes from the row labelled n (the last row under a default integer
# index), with values above 60 capped at 60.
# NOTE(review): presumably that row holds per-column totals - confirm.
na = np.array(dfAri.loc[n].apply(lambda x: 60 if x>60 else x)).astype(int)
nab = dfAri.values.astype(int)

dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values

# q marks with 1 every (a, b) pair visited below; it is not read again
# in this part of the file.
q = np.zeros(np.shape(nab),dtype=int)

for par in pars:
    # Columns whose name contains this substring.
    cols = [col for col in dfAri.columns if par in col]
    # The group is taken to be len(cols) consecutive columns starting at
    # the first match.
    # NOTE(review): this assumes the matching columns are adjacent - confirm.
    ida = list(dfAri.columns).index(cols[0])
    idz = ida+len(cols)
    # Every pair a < b inside the group.
    for a in np.arange(ida,idz):
        for b in np.arange(a+1,idz):
            q[a,b] = 1
            dummy.append(shat[na[a],na[b],nab[a,b]])
            dummz.append(60*2*nab[a,b]/(na[a]+na[b]))
            dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["ari"] = np.array(dummy)
pts["ari"] = np.array(dummz)
lci["ari"] = np.array(dummp)

# --- "ari2" dataset, read from data/ariquemes_2.xlsx ---
# '-' cells are replaced by 0 on load.
dfAri2 = pd.read_excel('data/ariquemes_2.xlsx').replace('-',0)
n = len(dfAri2)-1
# na comes from the row labelled n (the last row under a default integer
# index), with values above 60 capped at 60.
# NOTE(review): presumably that row holds per-parasite totals - confirm.
na = np.array(dfAri2.loc[n].apply(lambda x: 60 if x>60 else x)).astype(int)
nab = dfAri2.values.astype(int)
dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values
# NOTE(review): b stops at n-2, so no pair involving index n-1 is ever
# formed, while a still runs up to n-1 (its last two iterations do
# nothing). The "brazil" cell, which reads its file the same way, uses
# arange(n-1) / arange(a+1,n). Confirm whether index n-1 is meant to be
# excluded here or whether the two bounds were transposed.
for a in np.arange(n):
    for b in np.arange(a+1,n-1):
        dummy.append(shat[na[a],na[b],nab[a,b]])
        dummz.append(60*2*nab[a,b]/(na[a]+na[b]))
        dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["ari2"] = np.array(dummy)
pts["ari2"] = np.array(dummz)
lci["ari2"] = np.array(dummp)

# --- "wosera" dataset, read from data/tessema_wosera.xlsx ---
# Each column is treated as one parasite; a pair of columns shares a row
# when their entries in that row add up to 2.
# NOTE(review): presumably a 0/1 presence matrix - confirm.
dfWos = pd.read_excel('data/tessema_wosera.xlsx')
# compute na, the number of sequences per parasite
# NOTE(review): na is used directly as an index into shat and is not
# capped at 60 here, unlike the cells that read na from a spreadsheet row.
na = dfWos.sum().values
# number of parasites
n = len(na)
# compute nab, the number of shared sequences for each pair of parasites
pars = list(dfWos)
nab = np.zeros([n,n],dtype=int)
for a,para in enumerate(pars):
    for b,parb in enumerate(pars):
        nab[a,b] = np.sum((dfWos[para]+dfWos[parb])==2)
dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values
# Every pair a < b among all n parasites. The bounds used to be
# arange(n) / arange(a+1,n-1), which left out every pair that involves
# the last parasite.
for a in np.arange(n-1):
    for b in np.arange(a+1,n):
        dummy.append(shat[na[a],na[b],nab[a,b]])
        dummz.append(60*2*nab[a,b]/(na[a]+na[b]))
        dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["wosera"] = np.array(dummy)
pts["wosera"] = np.array(dummz)
lci["wosera"] = np.array(dummp)


# --- "amele_tessema" dataset, read from data/tessema_amele.xlsx ---
# Each column is treated as one parasite; a pair of columns shares a row
# when their entries in that row add up to 2.
# NOTE(review): presumably a 0/1 presence matrix - confirm.
dfAme2 = pd.read_excel('data/tessema_amele.xlsx')
# compute na, the number of sequences per parasite
# NOTE(review): na is used directly as an index into shat and is not
# capped at 60 here, unlike the cells that read na from a spreadsheet row.
na = dfAme2.sum().values
# number of parasites
n = len(na)
# compute nab, the number of shared sequences for each pair of parasites
pars = list(dfAme2)
nab = np.zeros([n,n],dtype=int)
for a,para in enumerate(pars):
    for b,parb in enumerate(pars):
        nab[a,b] = np.sum((dfAme2[para]+dfAme2[parb])==2)
dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values
# Every pair a < b among all n parasites. The bounds used to be
# arange(n) / arange(a+1,n-1), which left out every pair that involves
# the last parasite.
for a in np.arange(n-1):
    for b in np.arange(a+1,n):
        dummy.append(shat[na[a],na[b],nab[a,b]])
        dummz.append(60*2*nab[a,b]/(na[a]+na[b]))
        dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["amele_tessema"] = np.array(dummy)
pts["amele_tessema"] = np.array(dummz)
lci["amele_tessema"] = np.array(dummp)

# --- "mugil" dataset, read from data/tessema_mugil.xlsx ---
# Each column is treated as one parasite; a pair of columns shares a row
# when their entries in that row add up to 2.
# NOTE(review): presumably a 0/1 presence matrix - confirm.
dfMug = pd.read_excel('data/tessema_mugil.xlsx')
# compute na, the number of sequences per parasite
# NOTE(review): na is used directly as an index into shat and is not
# capped at 60 here, unlike the cells that read na from a spreadsheet row.
na = dfMug.sum().values
# number of parasites
n = len(na)
# compute nab, the number of shared sequences for each pair of parasites
pars = list(dfMug)
nab = np.zeros([n,n],dtype=int)
for a,para in enumerate(pars):
    for b,parb in enumerate(pars):
        nab[a,b] = np.sum((dfMug[para]+dfMug[parb])==2)
dummy = []  # shat lookups
dummz = []  # 60*2*nab/(na+nb)
dummp = []  # lenci(...) values
# Every pair a < b among all n parasites. The bounds used to be
# arange(n) / arange(a+1,n-1), which left out every pair that involves
# the last parasite.
for a in np.arange(n-1):
    for b in np.arange(a+1,n):
        dummy.append(shat[na[a],na[b],nab[a,b]])
        dummz.append(60*2*nab[a,b]/(na[a]+na[b]))
        dummp.append(lenci(na[a],na[b],nab[a,b]))
bro["mugil"] = np.array(dummy)
pts["mugil"] = np.array(dummz)
lci["mugil"] = np.array(dummp)
Mean samples per parasite: 15.6
# --- fig7: box plots per dataset, switchable between two quantities ---
# dat[0] is the pts dictionary, dat[1] the bro (shat) dictionary.
dat = [pts,bro]
# Pairs are kept when their lci value is <= the threshold.
# NOTE(review): only thr[0] and thr[1] are used below; thr[2] is unused.
thr = [61,61,15]
# Every plotted value is divided by sc.
sc = 1
# Dataset keys and the matching display names, in plotting order.
samples2 = ["ari","ari2","brazil","ame","kilifi"]
names = ["Ariq. clones","Ariq. isol.","Brazil","Amele","Kilifi"]
# NOTE(review): x1 is not used anywhere in this part of the file.
x1 = np.random.randn(500)

fig7 = go.Figure()


# Traces 0-4: pts values, visible when the figure first opens.
for x in np.arange(0,5):
    fig7.add_trace(go.Box(y = dat[0][samples2[x]][lci[samples2[x]]<=thr[0]]/sc,  marker_color = "black", fillcolor = "white", name = names[x],showlegend = False, visible = True ))


# Traces 5-9: shat values, hidden until selected in the dropdown.
for x in np.arange(0,5):
    fig7.add_trace(go.Box(y = dat[1][samples2[x]][lci[samples2[x]]<=thr[1]]/sc,  marker_color = "black", fillcolor = "white", name = names[x],showlegend = False, visible = False ))



# Dropdown with two entries; each sets the visibility of all ten traces
# and updates the y-axis title.
fig7.update_layout(plot_bgcolor='rgb(255,255,255)',  
                   updatemenus=[
                    dict(
                        active=0,
                        buttons=list([

                            # Show traces 0-4, hide traces 5-9.
                            dict(label="PTS X 60",
                             method="update",
                             args=[{"visible": 
                            [True,True,True,True, True,False, False, False, False, False]},
                               {"title": "",
                                'yaxis': {'title': 'PTS X 60', 'ticks' : 'outside', 'showline': True, 'linecolor': 'black' }}
                                  ]),

                        # Hide traces 0-4, show traces 5-9.
                        dict(label=" Estimated overlap \u015D",
                             method="update",
                             args=[{"visible": 
                                  [ False, False, False, False,False,True,True,True,True,True]},
                                  {"title": "",
                                   'yaxis': {'title': 'Estimated overlap \u015D', 'ticks' : 'outside', 'showline': True, 'linecolor': 'black' }}

                                ]),


            ]),
                        x=0.2,
                        y = 1.2

        )
])

# Initial axis styling; the y title matches the initially visible traces.
fig7.update_xaxes(ticks = 'outside', showline=True, linecolor='black')
fig7.update_yaxes(ticks = 'outside', showline=True, linecolor='black', title = 'PTS X 60')

# Write the figure to an HTML file.
plot(fig7, filename = 'plotly_figures/fig7.html', config = config)
#Binder
#iplot(fig7, filename = 'plotly_figures/fig7.html', config = config)

#ThebeLab
display(HTML('plotly_figures/fig7.html'))
# --- fig8: histograms of shat, split by lci band ---
# Histogram bin width.
ds=2
# Datasets selectable in the figure.
samples = ["ame","ari"]
# Band edges: rows [0,15], [15,30], [30,45], [45,60]; each row is [low, high].
bands = np.arange(0,46,15)
bands = np.transpose(np.array([bands,bands+15]))

# Layout holding one fixed text annotation, positioned in paper
# coordinates (fractions of the plotting area).
# NOTE(review): the annotation sits near the legend position set in the
# later update_layout call, so it presumably acts as the legend title - confirm.
layout=go.Layout(
        annotations=[
            go.layout.Annotation(
                # Spelling of "uncertainty" corrected (was 'uncertanity').
                text='CI Width (uncertainty)',
                align = 'right',
                showarrow=False,
                xref='paper',
                yref='paper',
                x=0.98,
                y=0.55,
            )
        ]
    )
fig8 = go.Figure(layout = layout)
# One bar colour per band.
clrs = ['rgb(12, 94, 190)','rgb(117, 170, 190)', 'rgb(208, 139, 115)','rgb(167, 36, 36)']
# Histogram bin edges of width ds, starting at -0.5.
bins1 = np.arange(0,63,ds)-0.5
for ids,sample in enumerate(samples):
    # shat values and lci values for every pair of this dataset.
    df = pd.DataFrame(data={"bro":bro[sample], "unc":lci[sample]})
    for idx,b in enumerate(bands):
        # Pairs whose lci value falls in the half-open band (b[0], b[1]].
        in_band = (df["unc"]>b[0]) & (df["unc"]<=b[1])
        counts, bins = np.histogram(df.loc[in_band,"bro"].values, bins = bins1)
        # Only the first dataset's traces are visible initially.
        fig8.add_trace(go.Bar(visible = (ids == 0), y=bins, x = counts, orientation='h',width = 2,
                              marker_color=clrs[idx], name ="{} to {}".format(b[0],b[1])))

            
# Dropdown with one entry per dataset; each sets the visibility of all
# eight traces (four bands for each of the two datasets).
fig8.update_layout(plot_bgcolor='rgb(255,255,255)', width = 700, 
                   updatemenus=[
                    dict(
                        active=0,
                        buttons=list([

                            # Show traces 0-3, hide traces 4-7.
                            dict(label="Amele",
                             method="update",
                             args=[{"visible": 
                            [True,True,True,True,False, False, False, False]},
                               {"title": "",
                                 }
                                  ]),

                        # Hide traces 0-3, show traces 4-7.
                        dict(label=" Ariq. clones",
                             method="update",
                             args=[{"visible": 
                                  [False, False, False,False,True,True,True,True]},
                                   {"title": "",
                                    }

                                ]),


            ]),
                        x=0.95,
                        y = 0.2

        )
                    ],legend=dict(x=.75, y=.35)

                  )            


# Axis styling; the x-axis range is fixed to [0, 30].
fig8.update_xaxes(ticks = 'outside', showline=True, linecolor='black',range=[0, 30])
fig8.update_yaxes(ticks = 'outside', showline=True, linecolor='black', title = 'Estimated overlap \u015D')            



# Write the figure to an HTML file.
plot(fig8, filename = 'plotly_figures/fig8.html', config = config)

#Binder
#iplot(fig8, filename = 'plotly_figures/fig8.html', config = config)


#ThebeLab
display(HTML('plotly_figures/fig8.html'))